package com.baomidou.mybatisplus.ext.util;

import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.enums.SqlMethod;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.GlobalConfigUtils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.session.SqlSession;
import org.mybatis.spring.SqlSessionUtils;

import java.lang.invoke.CallSite;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static java.lang.invoke.LambdaMetafactory.FLAG_SERIALIZABLE;

/**
 * Static helpers that build {@link QueryWrapper} conditions from entity property names and
 * run the resulting list / count / update / delete statements through a {@link SqlSession}.
 *
 * <p>Property names are translated to column names by {@link #getDBField(Class, String)}:
 * a non-empty {@code @TableId} or {@code @TableField} value wins, otherwise the property
 * name is converted from camelCase to snake_case.
 */
public class Funtions {

    /** Matches every ASCII upper-case letter; compiled once and reused by {@link #humpToLine(String)}. */
    private static final Pattern HUMP_PATTERN = Pattern.compile("[A-Z]");

    /**
     * Adds a single-value (or value-less) condition to the wrapper. TODO: not finished yet.
     *
     * <p>If the property cannot be resolved to a column the failure is printed to stderr and
     * the wrapper is left untouched. An operator that is not one of the cases below is
     * silently ignored.
     *
     * @param queryWrapper wrapper that receives the condition
     * @param operator     operator name, e.g. {@code "like"}; must not be {@code null}
     * @param po           entity instance whose class declares {@code property}
     * @param property     entity property name
     * @param value        comparison value; for {@code "in"} it must be a {@link Collection};
     *                     ignored by {@code isNull}, {@code notNull}, {@code orderByAsc}
     *                     and {@code orderByDesc}
     */
    public static void addWrapper(QueryWrapper queryWrapper, String operator, Object po, String property, Object value) {
        String field = resolveColumn(po, property);
        if (field == null) {
            // Resolution failed and was already reported; keep the original best-effort behavior.
            return;
        }
        switch (operator) {
            case "eq":
                queryWrapper.eq(field, value);
                break;
            case "lt":
                queryWrapper.lt(field, value);
                break;
            case "gt":
                queryWrapper.gt(field, value);
                break;
            case "le":
                queryWrapper.le(field, value);
                break;
            case "ge":
                queryWrapper.ge(field, value);
                break;
            case "ne":
                queryWrapper.ne(field, value);
                break;
            case "like":
                queryWrapper.like(field, value);
                break;
            case "likeLeft":
                queryWrapper.likeLeft(field, value);
                break;
            case "likeRight":
                queryWrapper.likeRight(field, value);
                break;
            case "notLike":
                queryWrapper.notLike(field, value);
                break;
            case "isNull":
                queryWrapper.isNull(field);
                break;
            case "notNull":
                queryWrapper.isNotNull(field);
                break;
            case "in":
                queryWrapper.in(field, (Collection) (value));
                break;
            case "orderByAsc":
                queryWrapper.orderByAsc(field);
                break;
            case "orderByDesc":
                queryWrapper.orderByDesc(field);
                break;
            default:
                // Unknown operators are deliberately ignored.
                break;
        }
    }

    /**
     * Adds a range condition to the wrapper. TODO: not finished yet.
     *
     * <p>If the property cannot be resolved to a column the failure is printed to stderr and
     * the wrapper is left untouched. Operators other than {@code "between"} and
     * {@code "notBetween"} are silently ignored.
     *
     * @param queryWrapper wrapper that receives the condition
     * @param operator     operator name, {@code "between"} or {@code "notBetween"}; must not be {@code null}
     * @param po           entity instance whose class declares {@code property}
     * @param property     entity property name
     * @param minValue     lower bound
     * @param maxValue     upper bound
     */
    public static void addWrapper(QueryWrapper queryWrapper, String operator, Object po, String property, Object minValue, Object maxValue) {
        String field = resolveColumn(po, property);
        if (field == null) {
            return;
        }
        switch (operator) {
            case "between":
                queryWrapper.between(field, minValue, maxValue);
                break;
            case "notBetween":
                queryWrapper.notBetween(field, minValue, maxValue);
                break;
            default:
                // Unknown operators are deliberately ignored.
                break;
        }
    }

    /**
     * Resolves the column for {@code property} on {@code po}'s class without throwing.
     *
     * @return the column name, or {@code null} when resolution failed (the failure has been
     *         printed to stderr)
     */
    private static String resolveColumn(Object po, String property) {
        try {
            return getDBField(po.getClass(), property);
        } catch (Exception e) {
            // Covers an unknown property as well as a null po; JVM Errors are left to propagate.
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Creates a new instance of the named class through its no-argument constructor.
     *
     * <p>Uses {@code getDeclaredConstructor().newInstance()} instead of the deprecated
     * {@code Class.newInstance()}. As a consequence an exception thrown by the constructor
     * itself is reported and turned into a {@code null} result like every other reflective
     * failure, instead of propagating to the caller.
     *
     * @param poClassName fully qualified class name of the entity
     * @return the new instance, or {@code null} if the class cannot be loaded or instantiated
     *         (the failure is printed to stderr)
     */
    public static Object newOBJ(String poClassName) {
        try {
            return Class.forName(poClassName).getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Runs the {@link SqlMethod#SELECT_LIST} statement of {@code po}'s class with the wrapper
     * bound under the key {@code "ew"}.
     *
     * @param wrapper query conditions
     * @param po      entity instance; only its class is used to locate the session and statement
     * @return the rows returned by the statement
     */
    public static List list(QueryWrapper wrapper, Object po) {
        Class<?> poClass = po.getClass();
        SqlSession sqlSession = sqlSession(poClass);
        Map<String, Object> map = CollectionUtils.newHashMapWithExpectedSize(1);
        map.put("ew", wrapper);
        try {
            return sqlSession.selectList(sqlStatement(SqlMethod.SELECT_LIST, poClass), map);
        } finally {
            closeSqlSession(sqlSession, poClass);
        }
    }

    /**
     * Same as {@link #list(QueryWrapper, Object)} but restricts the selected columns.
     *
     * @param wrapper query conditions
     * @param po      entity instance; only its class is used
     * @param fields  entity property names to select; must be a {@code String[]}
     * @return the rows returned by the statement, or an empty list if the select columns
     *         could not be applied (the failure is printed to stderr)
     */
    public static List list(QueryWrapper wrapper, Object po, Object fields) {
        try {
            addSelect(wrapper, po, fields);
        } catch (Throwable e) {
            e.printStackTrace();
            return new ArrayList();
        }
        return list(wrapper, po);
    }

    /**
     * Calls {@code wrapper.select(...)} with the columns that correspond to the given
     * entity property names.
     *
     * @param wrapper wrapper that receives the select columns
     * @param po      entity instance whose class declares the properties
     * @param fields  entity property names; must be a {@code String[]}
     * @throws Throwable if {@code fields} is not a {@code String[]} or a property cannot be resolved
     */
    public static void addSelect(QueryWrapper wrapper, Object po, Object fields) throws Throwable {
        String[] fieldsArr = (String[]) fields;
        List<String> dbFields = new ArrayList<>(fieldsArr.length);
        for (String field : fieldsArr) {
            dbFields.add(getDBField(po.getClass(), field));
        }
        wrapper.select(dbFields.toArray(new String[0]));
    }

    /**
     * Resolves the database column name for an entity property.
     *
     * <p>A non-empty {@code @TableId} value is used first, then a non-empty
     * {@code @TableField} value; otherwise the property name is converted with
     * {@link #humpToLine(String)}.
     *
     * @param entityClass entity class; its superclasses are searched too
     * @param fieldName   entity property name
     * @return the column name
     * @throws NoSuchFieldException if neither the class nor any superclass declares the property
     */
    private static String getDBField(Class<?> entityClass, String fieldName) throws NoSuchFieldException {
        Field field = getDeclaredField(entityClass, fieldName);
        if (field == null) {
            // Fail with a descriptive message instead of a bare NullPointerException below.
            throw new NoSuchFieldException(
                    "No field '" + fieldName + "' declared on " + entityClass.getName() + " or its superclasses");
        }
        // For the id field, prefer the value configured on @TableId.
        TableId tableId = field.getAnnotation(TableId.class);
        if (tableId != null && StringUtils.isNotEmpty(tableId.value())) {
            return tableId.value();
        }
        TableField tableField = field.getAnnotation(TableField.class);
        if (tableField != null && StringUtils.isNotEmpty(tableField.value())) {
            return tableField.value();
        }
        return humpToLine(fieldName);
    }

    /**
     * Converts camelCase to snake_case by replacing every ASCII upper-case letter with an
     * underscore followed by its lower-case form, e.g. {@code userName} becomes
     * {@code user_name}. A leading upper-case letter also gets an underscore.
     *
     * @param str string to convert; must not be {@code null}
     * @return the converted string
     */
    public static String humpToLine(String str) {
        Matcher matcher = HUMP_PATTERN.matcher(str);
        StringBuffer sb = new StringBuffer();
        while (matcher.find()) {
            // Locale.ROOT keeps the result stable regardless of the default locale
            // (a Turkish default locale would otherwise lower-case 'I' to a dotless i).
            matcher.appendReplacement(sb, "_" + matcher.group().toLowerCase(Locale.ROOT));
        }
        matcher.appendTail(sb);
        return sb.toString();
    }

    /**
     * Runs {@link #list(QueryWrapper, Object)} and returns the first row.
     *
     * @param wrapper query conditions
     * @param po      entity instance; only its class is used
     * @return the first row, or {@code null} when there is none
     */
    public static Object one(QueryWrapper wrapper, Object po) {
        List list = list(wrapper, po);
        return list.isEmpty() ? null : list.get(0);
    }

    /**
     * Runs {@link #list(QueryWrapper, Object, Object)} and returns the first row.
     *
     * @param wrapper query conditions
     * @param po      entity instance; only its class is used
     * @param fields  entity property names to select; must be a {@code String[]}
     * @return the first row, or {@code null} when there is none
     */
    public static Object one(QueryWrapper wrapper, Object po, Object fields) {
        List list = list(wrapper, po, fields);
        return list.isEmpty() ? null : list.get(0);
    }

    /**
     * Runs the {@link SqlMethod#SELECT_COUNT} statement of {@code po}'s class with the wrapper
     * bound under the key {@code "ew"}.
     *
     * @param wrapper query conditions
     * @param po      entity instance; only its class is used
     * @return the count, or {@code null} if the statement returned no value
     */
    public static Long count(QueryWrapper wrapper, Object po) {
        Class<?> poClass = po.getClass();
        SqlSession sqlSession = sqlSession(poClass);
        Map<String, Object> map = CollectionUtils.newHashMapWithExpectedSize(1);
        map.put("ew", wrapper);
        try {
            // Read as Number so that a statement mapped to Integer does not fail with a
            // ClassCastException. NOTE(review): the mapped result type depends on the
            // MyBatis-Plus version in use - confirm.
            Number total = sqlSession.selectOne(sqlStatement(SqlMethod.SELECT_COUNT, poClass), map);
            return total == null ? null : Long.valueOf(total.longValue());
        } finally {
            closeSqlSession(sqlSession, poClass);
        }
    }

    /**
     * Runs the {@link SqlMethod#UPDATE} statement of {@code po}'s class with the wrapper bound
     * under {@code "ew"} and the entity under {@code "et"}.
     *
     * @param wrapper update conditions
     * @param po      entity carrying the values to write
     * @return the number of affected rows
     */
    public static Integer update(QueryWrapper wrapper, Object po) {
        Class<?> poClass = po.getClass();
        SqlSession sqlSession = sqlSession(poClass);
        // Two entries ("ew" and "et") are stored, so size the map accordingly.
        Map<String, Object> map = CollectionUtils.newHashMapWithExpectedSize(2);
        map.put("ew", wrapper);
        map.put("et", po);
        try {
            return sqlSession.update(sqlStatement(SqlMethod.UPDATE, poClass), map);
        } finally {
            closeSqlSession(sqlSession, poClass);
        }
    }

    /**
     * Runs the {@link SqlMethod#DELETE} statement of {@code po}'s class with the wrapper bound
     * under the key {@code "ew"}.
     *
     * @param wrapper delete conditions
     * @param po      entity instance; only its class is used
     * @return the number of affected rows
     */
    public static Integer delete(QueryWrapper wrapper, Object po) {
        Class<?> poClass = po.getClass();
        SqlSession sqlSession = sqlSession(poClass);
        Map<String, Object> map = CollectionUtils.newHashMapWithExpectedSize(1);
        map.put("ew", wrapper);
        try {
            return sqlSession.delete(sqlStatement(SqlMethod.DELETE, poClass), map);
        } finally {
            closeSqlSession(sqlSession, poClass);
        }
    }

    /** Obtains the {@link SqlSession} for the given entity class via {@link SqlHelper}. */
    protected static SqlSession sqlSession(Class poClass) {
        return SqlHelper.sqlSession(poClass);
    }

    /** Returns the statement id of {@code sqlMethod} for the given entity class. */
    protected static String sqlStatement(SqlMethod sqlMethod, Class poClass) {
        return sqlStatement(sqlMethod.getMethod(), poClass);
    }

    /** Returns the statement id of the named SQL method for the given entity class. */
    protected static String sqlStatement(String sqlMethod, Class poClass) {
        return SqlHelper.table(poClass).getSqlStatement(sqlMethod);
    }

    /** Closes the session against the session factory registered for the given entity class. */
    protected static void closeSqlSession(SqlSession sqlSession, Class poClass) {
        SqlSessionUtils.closeSqlSession(sqlSession, GlobalConfigUtils.currentSessionFactory(poClass));
    }

    /**
     * Looks up a declared field by walking from {@code clazz} up through its superclasses,
     * stopping before {@code Object}.
     *
     * @param clazz     class to start from; {@code null}, interfaces and primitives yield {@code null}
     * @param fieldName name of the field, possibly declared on a superclass
     * @return the first matching field, or {@code null} if none is declared
     */
    public static Field getDeclaredField(Class<?> clazz, String fieldName) {
        if (fieldName == null) {
            return null;
        }
        // The null check ends the walk for interfaces/primitives, whose getSuperclass() is null.
        for (Class<?> current = clazz; current != null && current != Object.class; current = current.getSuperclass()) {
            try {
                return current.getDeclaredField(fieldName);
            } catch (NoSuchFieldException | SecurityException ignored) {
                // Intentionally ignored: rethrowing here would abort the walk before the
                // superclasses have been searched.
            }
        }
        return null;
    }
}
